from npu_bridge.npu_init import *
import numpy as np
import tensorflow as tf
def query_ball_point(radius, nsample, xyz, new_xyz):
    '''
    Ball query: for every query point, collect the indices of up to
    `nsample` points of `xyz` whose squared distance to it does not exceed
    `radius ** 2`.

    Indices are picked in ascending index order (not by distance). When
    fewer than `nsample` points fall inside the ball, the remaining slots
    are filled with the first (lowest) in-range index of that row.

    Args:
        radius: float32 local region radius
        nsample: max sample number in local region int32
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]

    Returns:
        group_idx: grouped points index, [B, S, nsample]

    Note:
        B, N and S must be statically known. If a query point has no point
        inside the ball at all, its whole row is filled with the
        out-of-range sentinel value N.
    '''
    B = int(xyz.get_shape()[0])
    N = int(xyz.get_shape()[1])
    S = int(new_xyz.get_shape()[1])

    # Candidate index grid [B, S, N], built in-graph so that no dense
    # B*S*N array has to be embedded into the graph as a constant.
    point_ids = tf.reshape(tf.range(N, dtype=tf.int32), [1, 1, N])
    group_idx = tf.tile(point_ids, [B, S, 1])

    sqrdists = _square_distance(new_xyz, xyz)
    # Out-of-ball candidates are replaced by the sentinel N, which is larger
    # than every valid index and therefore sorts to the end of each row.
    N_tensor = tf.constant(N, dtype=tf.int32, shape=[B, S, N])
    group_idx = tf.where(tf.greater(sqrdists, (radius ** 2)), N_tensor, group_idx)
    group_idx = tf.sort(group_idx, direction='ASCENDING')
    if nsample > N:
        # Fewer candidates than requested samples: pad with the sentinel so
        # the result keeps the documented [B, S, nsample] shape.
        padding = tf.constant(N, dtype=tf.int32, shape=[B, S, nsample - N])
        group_idx = tf.concat([group_idx, padding], 2)
    group_idx = group_idx[:, :, :nsample]

    # Replace leftover sentinels with the first selected index of the row.
    group_first = tf.tile(group_idx[:, :, :1], [1, 1, nsample])
    group_idx = tf.where(tf.equal(group_idx, N), group_first, group_idx)
    return group_idx


def group_point(points, idx):
    '''
    Gather, batch by batch, the rows of `points` selected by `idx`.

    Args:
        points: (batch_size, ndataset, channel) float32
        idx: (batch_size, npoint, nsample) int32

    Returns:
        out: (batch_size, npoint, nsample, channel) float32, or None when
            the static batch size is 0.

    Note:
        The batch size of `points` must be statically known.
    '''
    bz = int(points.get_shape()[0])
    if bz <= 0:
        return None

    # One gather per batch element, each given a leading axis of size 1 ...
    per_batch = [
        tf.expand_dims(tf.gather(points[i], idx[i], axis=0), axis=0)
        for i in range(bz)
    ]
    # ... then joined with a single concat instead of re-concatenating the
    # growing result on every iteration.
    return tf.concat(per_batch, 0)


def knn_point(k, xyz1, xyz2):
    '''
    Brute-force k-nearest-neighbour search of query points in a point set.

    Input:
        k: int32, number of k in k-nn search
        xyz1: (batch_size, ndataset, c) float32 array, input points
        xyz2: (batch_size, npoint, c) float32 array, query points
    Output:
        val: (batch_size, npoint, k) float32 array, squared L2 distances to
            the k nearest input points, nearest first
        idx: (batch_size, npoint, k) int32 array, indices to input points
    '''
    # Broadcast (b, 1, n, c) against (b, m, 1, c) to get every
    # query/input difference without materialising tiled copies.
    diff = tf.expand_dims(xyz1, 1) - tf.expand_dims(xyz2, 2)
    dist = tf.reduce_sum(diff ** 2, -1)

    # top_k selects the largest entries, so search on the negated distances
    # and flip the sign back to hand out real (non-negative) distances.
    neg_val, idx = tf.nn.top_k(-dist, k=k)
    val = -neg_val

    return val, idx

def _square_distance(src, dst):
    '''
    Pairwise squared Euclidean distance between two batched point sets.

    Computed through the expansion |s - d|^2 = |s|^2 - 2 * s.d + |d|^2, so
    the cross term is a single batched matmul. Rounding in this expansion
    can leave tiny negative values for (near-)coincident points.

    Args:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]

    Returns:
        dist: [B, N, M], where dist[b, i, j] is the squared distance
            between src[b, i] and dst[b, j]
    '''
    cross = tf.matmul(src, tf.transpose(dst, perm=(0, 2, 1)))
    # expand_dims instead of reshape: no static shape values are needed.
    src_sq = tf.expand_dims(tf.reduce_sum(src ** 2, -1), 2)  # [B, N, 1]
    dst_sq = tf.expand_dims(tf.reduce_sum(dst ** 2, -1), 1)  # [B, 1, M]
    return -2 * cross + src_sq + dst_sq

if __name__=='__main__':
    # Smoke test / micro-benchmark: group random features with either
    # k-NN or ball-query indices and time 100 session runs.
    import time

    # True -> k-NN grouping, False -> ball-query grouping.
    knn = True

    np.random.seed(100)
    feature_array = np.random.random((32,512,64)).astype('float32')
    dataset_array = np.random.random((32,512,3)).astype('float32')
    query_array = np.random.random((32,128,3)).astype('float32')

    points = tf.constant(feature_array)
    xyz1 = tf.constant(dataset_array)
    xyz2 = tf.constant(query_array)
    radius = 0.1
    nsample = 64

    if knn:
        idx = knn_point(nsample, xyz1, xyz2)[1]
    else:
        idx = query_ball_point(radius, nsample, xyz1, xyz2)
    grouped_points = group_point(points, idx)

    with tf.Session('') as sess:
        started = time.time()
        for _ in range(100):
            ret = sess.run(grouped_points)
        elapsed = time.time() - started
        print (elapsed)
        print (ret.shape, ret.dtype)
        print (ret)